import torch
import transformers
import time
from openai import OpenAI
import shlex
import re


class PentestAgent:
    """LLM-driven agent that proposes shell commands and runs them in a container.

    Two LLM roles are exposed as separate methods; the driving loop lives in
    the caller:

    * :meth:`planner` / :meth:`plan_and_run_cmd` ask the model for the next
      command (wrapped in ``<CMD>...</CMD>``), execute it inside ``container``
      and store "command + output" in ``new_observation``;
    * :meth:`summarizer` folds ``new_observation`` into ``summarized_history``,
      which is substituted into the next planner prompt.

    The backend is either a local ``transformers`` text-generation pipeline
    (``llm_model_local=True``) or the OpenAI chat-completions API (model ids
    containing ``"gpt"`` or ``"o1"``).
    """

    # DOTALL so that commands the model spreads over several lines are found.
    _CMD_PATTERN = re.compile(r'<CMD>(.*?)</CMD>', re.DOTALL)

    def __init__(
            self,
            llm_model_id,
            llm_model_local,
            temperature,
            top_p,
            container,
            planner_system_prompt,
            planner_user_prompt,
            summarizer_user_prompt,
            summarizer_system_prompt,
            prompt_chaining=False,
            target_text=None,
            timeout_duration=10,
            do_sample=False,
            max_new_tokens=1024,
            new_observation_length_limit=2000,
            print_end_sep="\n⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯\n"
            ):
        """Store the configuration and build the LLM backend.

        Args:
            llm_model_id: Hugging Face model id (local) or OpenAI model name.
            llm_model_local: load ``llm_model_id`` with ``transformers`` locally.
            temperature, top_p: sampling parameters forwarded to the backend.
            container: container handle providing ``status``, ``reload()``,
                ``start()`` and ``exec_run()`` (presumably a docker SDK
                ``Container``).
            planner_system_prompt: system message for the planner.
            planner_user_prompt: template with a ``{summarized_history}`` field.
            summarizer_user_prompt: template with ``{summarized_history}`` and
                ``{new_observation}`` fields.
            summarizer_system_prompt: system message for the summarizer.
            prompt_chaining: replay the previous user/assistant turn of the
                same role in every request.
            target_text: optional text appended to every planner prompt.
            timeout_duration: seconds passed to ``timeout`` for each command.
            do_sample: sampling flag for the local pipeline.
            max_new_tokens: generation length limit.
            new_observation_length_limit: max characters kept in
                ``new_observation`` before truncation.
            print_end_sep: separator printed after verbose output.

        Raises:
            ValueError: if the model is neither local nor a supported OpenAI id.
        """
        self.target_text = target_text
        self.container = container
        self.timeout_duration = timeout_duration

        self.max_new_tokens = max_new_tokens
        self.llm_model_id = llm_model_id
        self.temperature = temperature
        self.top_p = top_p
        self.do_sample = do_sample
        self.llm_model_local = llm_model_local
        self.llm_pipeline = self.create_llm_pipeline()

        self.summarized_history = ""
        self.new_observation = ""
        self.new_observation_length_limit = new_observation_length_limit
        self.prompt_chaining = prompt_chaining

        self.planner_system_prompt = planner_system_prompt
        self.planner_user_prompt = planner_user_prompt
        self.planner_prompts = []
        self.planner_outputs = []

        self.summarizer_system_prompt = summarizer_system_prompt
        self.summarizer_user_prompt = summarizer_user_prompt
        self.summarizer_prompts = []
        self.summarizer_outputs = []

        self.print_end_sep = print_end_sep

    def reset(self):
        """Clear the summary, the last observation and all prompt/output logs."""
        self.summarized_history = ""
        self.new_observation = ""

        self.planner_outputs = []
        self.planner_prompts = []

        self.summarizer_outputs = []
        self.summarizer_prompts = []

    def plan_and_run_cmd(self, verbose=True):
        """Ask the planner for a command, run it in the container, record it.

        The first ``<CMD>...</CMD>`` span of the planner output (it may span
        several lines; surrounding whitespace is stripped) is executed as
        ``timeout <timeout_duration>s /bin/bash -c <cmd>``. ``"<cmd>:\\n<output>"``
        is stored in ``new_observation``, truncated to
        ``new_observation_length_limit`` characters plus a marker.

        Args:
            verbose: print the planner output and the command output.

        Returns:
            ``(planner_output, cmd_to_run, command_output, input_token_count,
            output_token_count)``. ``cmd_to_run`` is ``"*No command*"`` when no
            ``<CMD>`` tag is found; ``command_output`` is ``"*No output.*"``
            when the command produced no (non-whitespace) output.
        """
        planner_output, input_token_count, output_token_count = self.planner()

        match = self._CMD_PATTERN.search(planner_output)

        if match:
            cmd_to_run = match.group(1).strip()
            command_output = self._run_in_container(cmd_to_run)
        else:
            print("No command found.")
            cmd_to_run = "*No command*"
            command_output = ""

        if not command_output.strip():
            command_output = "*No output.*"

        if verbose:
            print(f"Planner output '{planner_output}':")
            print(command_output, end=self.print_end_sep)

        self.new_observation = f"{cmd_to_run}:\n{command_output}"
        if len(self.new_observation) > self.new_observation_length_limit:
            self.new_observation = self.new_observation[:self.new_observation_length_limit] + " *Output truncated*"
            print("New observation truncated")

        return planner_output, cmd_to_run, command_output, input_token_count, output_token_count

    def _run_in_container(self, cmd):
        """Run ``cmd`` via bash under ``timeout`` in the container; return stripped text output."""
        # `status` is cached on docker SDK container objects; refresh it so a
        # container that stopped since the last check gets restarted.
        self.container.reload()
        if self.container.status != 'running':
            self.container.start()

        # The docker SDK splits string commands shell-style, so quoting keeps
        # `cmd` as the single argument of `bash -c`.
        result = self.container.exec_run(
            f"timeout {self.timeout_duration}s /bin/bash -c {shlex.quote(cmd)}"
        )
        # Commands may emit arbitrary bytes; invalid UTF-8 must not crash the agent.
        return result.output.decode('utf-8', errors='replace').strip()

    def _is_openai_model(self):
        """Return True when ``llm_model_id`` is routed to the OpenAI API."""
        return "gpt" in self.llm_model_id or "o1" in self.llm_model_id

    def _is_openai_reasoning_model(self):
        """Return True for o1-family ids, which take different request parameters."""
        return "o1" in self.llm_model_id and "gpt" not in self.llm_model_id

    def create_llm_pipeline(self):
        """Build the LLM backend selected by ``llm_model_local`` / ``llm_model_id``.

        Returns:
            A ``transformers`` text-generation pipeline for local models, or an
            ``OpenAI`` client for OpenAI model ids.

        Raises:
            ValueError: if the model is neither local nor a supported OpenAI id
                (previously this returned ``None`` and failed later with an
                ``UnboundLocalError`` in :meth:`generate_text`).
        """
        if self.llm_model_local:

            tokenizer = transformers.AutoTokenizer.from_pretrained(self.llm_model_id)

            model = transformers.AutoModelForCausalLM.from_pretrained(
                self.llm_model_id,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                _attn_implementation='eager'
            )

            pipeline = transformers.pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
            )
            return pipeline
        if self._is_openai_model():
            return OpenAI()
        raise ValueError(
            f"Unsupported non-local model id {self.llm_model_id!r}: expected a "
            "local model or an OpenAI id containing 'gpt' or 'o1'."
        )

    def generate_text_local(self, messages):
        """Generate a reply for chat ``messages`` with the local pipeline.

        Token counts come from re-tokenizing the rendered prompt and the
        generated text, so they may differ slightly from what the model
        consumed (e.g. special tokens).

        Returns:
            ``(generated_text, input_token_count, output_token_count)``.
        """
        tokenizer = self.llm_pipeline.tokenizer
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        input_token_count = len(tokenizer(prompt)['input_ids'])

        outputs = self.llm_pipeline(
            prompt,
            max_new_tokens=self.max_new_tokens,
            do_sample=self.do_sample,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        # The pipeline returns prompt + completion; keep only the completion.
        generated_text = outputs[0]["generated_text"][len(prompt):]

        output_token_count = len(tokenizer(generated_text)['input_ids'])

        # If the EOS token survived decoding, keep what follows its last
        # occurrence. Skip when there is no EOS token: str.split(None) would
        # split on whitespace and keep only the last word.
        eos_token = tokenizer.eos_token
        if eos_token:
            generated_text = generated_text.split(eos_token)[-1]

        return generated_text, input_token_count, output_token_count

    def generate_text(self, messages):
        """Generate a reply for chat ``messages`` with the configured backend.

        Returns:
            ``(generated_text, input_token_count, output_token_count)``.

        Raises:
            ValueError: if the model is neither local nor a supported OpenAI id.
        """
        if self.llm_model_local:
            return self.generate_text_local(messages=messages)
        if not self._is_openai_model():
            raise ValueError(f"Unsupported non-local model id {self.llm_model_id!r}.")

        request = {"model": self.llm_model_id, "messages": messages, "store": True}
        if self._is_openai_reasoning_model():
            # o1-family models reject `max_tokens` and non-default sampling
            # parameters. NOTE(review): some o1 variants also reject the
            # "system" role — confirm for the models in use.
            request["max_completion_tokens"] = self.max_new_tokens
        else:
            request.update(
                temperature=self.temperature,
                top_p=self.top_p,
                max_tokens=self.max_new_tokens,
            )

        response = self.llm_pipeline.chat.completions.create(**request)
        generated_text = response.choices[0].message.content
        input_token_count = response.usage.prompt_tokens
        output_token_count = response.usage.completion_tokens
        return generated_text, input_token_count, output_token_count

    def _chat_messages(self, system_prompt, past_prompts, past_outputs, user_prompt):
        """Build a chat message list, replaying the previous turn when chaining is on."""
        messages = [{"role": "system", "content": system_prompt}]
        if self.prompt_chaining and past_outputs:
            messages.append({"role": "user", "content": past_prompts[-1]})
            messages.append({"role": "assistant", "content": past_outputs[-1]})
        messages.append({"role": "user", "content": user_prompt})
        return messages

    def planner(self):
        """Ask the LLM for the next step given the current summarized history.

        The user prompt is ``planner_user_prompt`` formatted with
        ``summarized_history``, followed by ``target_text`` when one is set.
        The prompt and the reply are appended to ``planner_prompts`` /
        ``planner_outputs``.

        Returns:
            ``(output, input_token_count, output_token_count)``.
        """
        user_prompt = self.planner_user_prompt.format(summarized_history=self.summarized_history)

        # target_text defaults to None; only append it when provided.
        if self.target_text:
            user_prompt += self.target_text

        messages = self._chat_messages(
            self.planner_system_prompt, self.planner_prompts, self.planner_outputs, user_prompt
        )
        output, input_token_count, output_token_count = self.generate_text(messages=messages)

        self.planner_prompts.append(user_prompt)
        self.planner_outputs.append(output)

        return output, input_token_count, output_token_count

    def summarizer(self, verbose=True):
        """Fold ``new_observation`` into ``summarized_history`` using the LLM.

        The reply replaces ``summarized_history``; prompt and reply are
        appended to ``summarizer_prompts`` / ``summarizer_outputs``.

        Args:
            verbose: print the new summary.

        Returns:
            ``(output, input_token_count, output_token_count)``.
        """
        user_prompt = self.summarizer_user_prompt.format(
            summarized_history=self.summarized_history, new_observation=self.new_observation
        )

        messages = self._chat_messages(
            self.summarizer_system_prompt, self.summarizer_prompts, self.summarizer_outputs, user_prompt
        )
        output, input_token_count, output_token_count = self.generate_text(messages=messages)

        self.summarized_history = output

        self.summarizer_prompts.append(user_prompt)
        self.summarizer_outputs.append(output)

        if verbose:
            print(f"Current summary:\n{self.summarized_history}", end=self.print_end_sep)

        return output, input_token_count, output_token_count

    def download_files(self, urls):
        """Download each URL inside the container with ``wget``.

        Each file is saved, relative to the exec's working directory, under
        the last ``/``-separated component of its URL. ``wget`` output is
        discarded.
        """
        for url in urls:
            file_name = url.split("/")[-1]
            # URLs routinely contain shell metacharacters (`&`, `?`, `;`);
            # quote both arguments so bash passes them to wget verbatim.
            wget_cmd = f"wget {shlex.quote(url)} -O {shlex.quote(file_name)}"
            self.container.exec_run(f"/bin/bash -c {shlex.quote(wget_cmd)}")
